'''
Answer question Q4: How effective are the optimal activities compared to non-optimal ones in their effect on behavior?

Author: Meng Zhang
Date: January 2024
Disclaimer: adapted from the analysis code https://doi.org/10.4121/22153898.v1

Input: RL_trasition_weighted_reward.csv
       all_states.csv
Output: Figure 6
'''

import itertools
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# To compute RL-related things
import Calculate_Q_values  as cal_q
import Utils as util

# Size of the action space: there are 14 actions
NUM_ACTIONS = 14

# NOTE: `min` and `max` shadow the Python builtins from here on. The names are
# kept because later sections of this script pass them on under these names.
# Presumably they are the mean / minimum / maximum of the weighted reward
# computed with weight 0.5 -- confirm in Utils.
df_weighted, mean, min, max = util.weighted_sum_of_reward__for_transitions(0.5)

# NOTE(review): the state columns are parsed with `eval`, which executes
# whatever expression is stored in the cell. Only run this on trusted CSV
# files (ast.literal_eval would be the safer parser for plain lists).
data = pd.read_csv(
    "RL_trasition_weighted_reward.csv",
    converters={'Binary_State': eval, 'Binary_State_Next_Session': eval},
)
all_people = list(set(data['rand_id'].tolist()))  # unique participant ids
NUM_PEOPLE = len(all_people)
print(f"Total number of samples: {len(data)}.")
print(f"Total number of people: {NUM_PEOPLE}.")

df_all_states = pd.read_csv("all_states.csv", converters={'Binary_State': eval})

print(f"Total number of states: {len(df_all_states)}.")



#### FEATURE SELECTION
# Number of binary state features to keep and the candidate feature indices
NUM_FEAT_TO_SELECT = 3
CANDIDATE_FEATURES = list(range(11))
# Lower / upper bound handed to the effort-to-reward mapping
OUTPUT_LOWER, OUTPUT_HIGHER = -1, 1

data_train = data.copy(deep=True)
reward_mean = mean
map_to_rewards = util.get_map_effort_reward(
    reward_mean, OUTPUT_LOWER, OUTPUT_HIGHER, min, max
)

# Feature selection receives the rows without the 'rand_id' and
# 'session_num' columns, as plain nested lists.
df_feat = data_train.drop(columns=['rand_id', 'session_num'])
data_feat = df_feat.values.tolist()
feat_sel = cal_q.feature_selection(
    data_feat, reward_mean, min, max, CANDIDATE_FEATURES, NUM_FEAT_TO_SELECT
)
print("Features selected:", feat_sel, "Weighted_mean", reward_mean)



#### AVERAGE REWARD OVER TIME
### Compute Q-values and dynamics
data_train_q = data_train[
    ["Binary_State", "Binary_State_Next_Session", "cluster_new_index", "weighted_reward"]
].values.tolist()
q_values, reward_func, trans_func, _ = cal_q.compute_q_vals_dynamics(
    data_train_q, reward_mean, min, max, feat_sel, num_act=NUM_ACTIONS
)

# Derive the best, worst and average policy together with the reward and
# transition function each of them induces.
# The plain greedy policy is only printed; it is replaced right afterwards by
# the policy returned from util.get_opt_policy_without_repeat.
opt_policy = [np.argmax(q_values[s]) for s in range(len(q_values))]
print("opt policy", opt_policy)
opt_policy, _ = util.get_opt_policy_without_repeat(q_values)
print("opt policy remove repeatition", opt_policy)

worst_policy = [np.argmin(q_values[s]) for s in range(len(q_values))]

# Indices of the reduced states (one per combination of the selected features)
state_ids = range(2 ** NUM_FEAT_TO_SELECT)

trans_func_opt_policy = np.array([trans_func[s][opt_policy[s]] for s in state_ids])
reward_func_opt_policy = np.array([reward_func[s][opt_policy[s]] for s in state_ids])

trans_func_worst_policy = np.array([trans_func[s][worst_policy[s]] for s in state_ids])
reward_func_worst_policy = np.array([reward_func[s][worst_policy[s]] for s in state_ids])

# Average the transition and reward functions over the NUM_ACTIONS actions of
# a state, i.e. as if every action were chosen equally often. Of course this
# is a theoretical construct.
trans_func_avg_policy = np.array([sum(trans_func[s]) / NUM_ACTIONS for s in state_ids])
reward_func_avg_policy = np.array([sum(reward_func[s]) / NUM_ACTIONS for s in state_ids])

print("Optimal policy in each state:", opt_policy)
print("Worst policy in each state:", worst_policy)


### Compute percentiles from actual data
# Efforts of all rows with session_num < 2, mapped to rewards
efforts_s1 = data[data["session_num"] < 2]["weighted_reward"].to_list()
rewards_s1 = util.map_efforts_to_rewards(efforts_s1, map_to_rewards)
# NOTE(review): 60 and 90 are absent from this list -- confirm that is intended.
perc_nums = [0, 10, 20, 30, 40, 50, 70, 80, 100]
# Linear interpolation is np.percentile's default. The `interpolation=`
# keyword is deprecated (renamed to `method` in NumPy 1.22), so it is simply
# omitted, which gives the same result on old and new NumPy versions.
percentiles = np.percentile(rewards_s1, perc_nums)
print("percentiles", percentiles)

### State distribution from session 1 as the initial population
all_states_s1 = df_all_states[df_all_states['session_num'] == 1]
# Every combination of the selected binary features; the position of a
# combination in this list is used as its state index.
all_states = [list(bits) for bits in itertools.product([0, 1], repeat=NUM_FEAT_TO_SELECT)]
all_states_count = np.zeros(len(all_states))

# Tally the session-1 rows per reduced state
for binary_state in all_states_s1["Binary_State"]:
    state = list(np.take(binary_state, feat_sel))  # keep only the selected features
    all_states_count[all_states.index(state)] += 1

all_states_frac = all_states_count / sum(all_states_count)

print("Fraction of people in each state in session 1:", np.round(all_states_frac, 2))


### Average reward per transition for different numbers of time steps
def _simulate_mean_rewards(initial_pop, num_steps_list, policies):
    """Simulate each policy and return its mean reward per transition.

    For every horizon in ``num_steps_list`` the population ``initial_pop`` is
    pushed through a policy's transition matrix step by step. The reward
    collected at each step is accumulated and the total is divided by
    ``num_steps * sum(initial_pop)``. Progress is printed per horizon.

    Args:
        initial_pop: 1-D array with the share (or count) of people per state.
        num_steps_list: Horizons (numbers of time steps) to evaluate.
        policies: Sequence of ``(label, trans_func, reward_func)`` tuples
            holding the print label and the NumPy arrays of one policy.

    Returns:
        One list per policy, in the order of ``policies``, containing the
        mean reward per transition for each entry of ``num_steps_list``.
    """
    pop_size = sum(initial_pop)
    # The transposed transition matrix and the transposed element-wise product
    # of transitions and rewards are the same for every step, so they are
    # computed once per policy.
    # NOTE(review): if reward_func is 1-D (one value per state), broadcasting
    # applies it along the last (next-state) axis of trans_func -- confirm
    # this is the intended pairing of rewards and transitions.
    dynamics = [(label, trans.T, (trans * rew).T) for label, trans, rew in policies]
    mean_rewards = [[] for _ in policies]

    for num_steps in num_steps_list:
        print("\nNumber of time steps:", num_steps)
        for results, (label, trans_t, trans_time_rew_t) in zip(mean_rewards, dynamics):
            pop = initial_pop
            total_reward = 0
            for _ in range(num_steps):
                total_reward += sum(trans_time_rew_t.dot(pop))  # total reward for transitions
                pop = trans_t.dot(pop)  # new population
            mean_reward = total_reward / (num_steps * pop_size)
            print(label + ":", round(mean_reward, 2))
            results.append(mean_reward)

    return mean_rewards


num_steps_list = [1, 2, 3, 5, 10, 20, 30, 50, 100]
# Print label plus transition / reward arrays of the three policies compared
policy_dynamics = [
    ("Opt policy", trans_func_opt_policy, reward_func_opt_policy),
    ("Worst policy", trans_func_worst_policy, reward_func_worst_policy),
    ("Avg policy", trans_func_avg_policy, reward_func_avg_policy),
]

# Starting population: state distribution of all session-1 rows
initial_pop = all_states_frac
rew_opt_policy_list, rew_worst_policy_list, rew_avg_policy_list = _simulate_mean_rewards(
    initial_pop, num_steps_list, policy_dynamics
)


### People with reward in lowest 25%-percentile from the first session as starting population
# Determine initial population
all_states = [list(i) for i in itertools.product([0, 1], repeat=NUM_FEAT_TO_SELECT)]
all_states_count = np.zeros(2 ** NUM_FEAT_TO_SELECT)
all_people_s1s2 = data[data["session_num"] == 1]
# Linear interpolation is np.percentile's default; the deprecated
# `interpolation=` keyword (renamed to `method` in NumPy 1.22) is omitted.
reward_25_perc = np.percentile(rewards_s1, 25)

for binary_state, effort in zip(all_people_s1s2["Binary_State"],
                                all_people_s1s2["weighted_reward"]):
    reward = util.map_efforts_to_rewards([effort], map_to_rewards)[0]
    if reward <= reward_25_perc:
        # Only rows at or below the 25th percentile count towards the population
        state = list(np.take(binary_state, feat_sel))
        all_states_count[all_states.index(state)] += 1

all_states_frac_l25 = all_states_count / sum(all_states_count)
print("Fraction of people in each state in session 1 within lowest 25% of reward:", np.round(all_states_frac_l25, 2))

# Average reward per transition over time
initial_pop = all_states_frac_l25
rew_opt_policy_list_l25, rew_worst_policy_list_l25, rew_avg_policy_list_l25 = _simulate_mean_rewards(
    initial_pop, num_steps_list, policy_dynamics
)



#### CREATE FIGURE 6
sns.set()
sns.set_style("white")

med_fontsize = 22
small_fontsize = 18
extrasmall_fontsize = 15
sns.set_context("paper", rc={"font.size":small_fontsize,"axes.titlesize":med_fontsize,"axes.labelsize":med_fontsize,
                            'xtick.labelsize':small_fontsize, 'ytick.labelsize':small_fontsize,
                            'legend.fontsize':extrasmall_fontsize,'legend.title_fontsize': extrasmall_fontsize})

plt.figure(figsize=(10,5))

fig_lower_bound = -0.5  # y-axis lower limit for figure
fig_higher_bound = 0.5  # y-axis upper limit for figure

# The horizons are drawn at equidistant x positions and labelled with the
# actual numbers of time steps via plt.xticks below.
x_vals = np.arange(len(num_steps_list))

# Plot average reward for the three policies.
# Solid lines use the full session-1 population; dashed lines use the
# population within the lowest 25% of reward. No legend is drawn, so the text
# annotations further down identify the lines.
plt.plot(x_vals, rew_opt_policy_list, color = 'deepskyblue', label = "Optimal policy")
plt.plot(x_vals, rew_avg_policy_list, color = 'slategray', label = "Average policy")
plt.plot(x_vals, rew_worst_policy_list, color = 'black', label = "Worst policy")
plt.plot(x_vals, rew_opt_policy_list_l25, color = 'deepskyblue', linestyle = 'dashed')
plt.plot(x_vals, rew_avg_policy_list_l25, color = 'slategray', linestyle = 'dashed')
plt.plot(x_vals, rew_worst_policy_list_l25, color = 'black', linestyle = 'dashed')

plt.ylabel("Mean Reward per Transition")
plt.xlabel("Number of Time Steps")

plt.ylim([fig_lower_bound, fig_higher_bound])
plt.xticks(x_vals, num_steps_list)

# Annotations for reward lines (positions are in data coordinates)
plt.text(5.5, 0.20, s = "Optimal Policy", color = 'deepskyblue')
plt.text(5.5, 0.05, s = "Average Policy", color = 'slategray')
plt.text(5.5, -0.095, s = "Worst Policy", color = 'black')

figure_path = "Figures/Figure_6.pdf"
# savefig does not create missing directories, so make sure the output
# folder exists instead of failing with FileNotFoundError.
os.makedirs(os.path.dirname(figure_path), exist_ok=True)
plt.savefig(figure_path, dpi=1500,
            bbox_inches='tight', pad_inches=0)

